{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {},
  "cells": [
    {
      "metadata": {},
      "source": [
        "<td>\n",
        "   <a target=\"_blank\" href=\"https://labelbox.com\" ><img src=\"https://labelbox.com/blog/content/images/2021/02/logo-v4.svg\" width=256/></a>\n",
        "</td>"
      ],
      "cell_type": "markdown"
    },
    {
      "metadata": {},
      "source": [
        "<td>\n",
        "<a href=\"https://colab.research.google.com/github/Labelbox/labelbox-python/blob/master/examples/annotation_import/import_labeled_dataset_image.ipynb\" target=\"_blank\"><img\n",
        "src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"></a>\n",
        "</td>\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "<td>\n",
        "<a href=\"https://github.com/Labelbox/labelbox-python/blob/master/examples/annotation_import/import_labeled_dataset_image.ipynb\" target=\"_blank\"><img\n",
        "src=\"https://img.shields.io/badge/GitHub-100000?logo=github&logoColor=white\" alt=\"GitHub\"></a>\n",
        "</td>"
      ],
      "cell_type": "markdown"
    },
    {
      "metadata": {},
      "source": [
        "!pip install -q \"labelbox[data]\""
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "import labelbox as lb\n",
        "from labelbox.schema.data_row_metadata import DataRowMetadataField, DataRowMetadataKind\n",
        "import datetime\n",
        "import random\n",
        "import os\n",
        "import json\n",
        "from PIL import Image\n",
        "from labelbox.schema.ontology import OntologyBuilder, Tool\n",
        "import requests\n",
        "from tqdm.notebook import tqdm\n",
        "import uuid\n",
        "from labelbox.data.annotation_types import Label, ImageData, ObjectAnnotation, Rectangle, Point"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "# Setup Labelbox client"
      ],
      "cell_type": "markdown"
    },
    {
      "metadata": {},
      "source": [
        "# Initialize the Labelbox client.\n",
        "# Paste your API key below, or leave it empty to fall back to the LABELBOX_API_KEY\n",
        "# environment variable so the key never has to be saved inside the notebook.\n",
        "API_KEY = \"\"\n",
        "client = lb.Client(api_key=API_KEY or None)"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "# Download a public dataset\n"
      ],
      "cell_type": "markdown"
    },
    {
      "metadata": {},
      "source": [
        "# Function to download files\n",
        "def download_files(filemap):\n",
        "    \"\"\"Download a file to a local path unless that path already exists.\n",
        "\n",
        "    Args:\n",
        "        filemap: A (path, uri) pair: the local destination and the URL to fetch.\n",
        "\n",
        "    Returns:\n",
        "        The local path taken from filemap.\n",
        "\n",
        "    Raises:\n",
        "        requests.HTTPError: If the server answers with an error status.\n",
        "    \"\"\"\n",
        "    path, uri = filemap\n",
        "    if not os.path.exists(path):\n",
        "        # Write under a temporary name first so an interrupted download is never mistaken for a cached file\n",
        "        partial_path = f\"{path}.part\"\n",
        "        with requests.get(uri, stream=True, timeout=60) as response:\n",
        "            # Fail loudly instead of saving an HTTP error page as if it were the dataset\n",
        "            response.raise_for_status()\n",
        "            with open(partial_path, 'wb') as f:\n",
        "                for chunk in response.iter_content(chunk_size=8192):\n",
        "                    f.write(chunk)\n",
        "        os.replace(partial_path, path)\n",
        "    return path"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "# Public URLs of the sample data rows and of their annotations\n",
        "DATA_ROWS_URL = \"https://storage.googleapis.com/labelbox-datasets/VHR_geospatial/geospatial_datarows.json\"\n",
        "ANNOTATIONS_URL = \"https://storage.googleapis.com/labelbox-datasets/VHR_geospatial/geospatial_annotations.json\"\n",
        "\n",
        "# Save both files in the working directory under fixed local names\n",
        "download_files(filemap=(\"data_rows.json\", DATA_ROWS_URL))\n",
        "download_files(filemap=(\"annotations.json\", ANNOTATIONS_URL))"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "# Load data rows and annotations.\n",
        "# The files are read as UTF-8 explicitly so parsing does not depend on the platform's default encoding.\n",
        "with open('data_rows.json', encoding='utf-8') as fp:\n",
        "    data_rows = json.load(fp)\n",
        "with open('annotations.json', encoding='utf-8') as fp:\n",
        "    annotations = json.load(fp)"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "# Create a dataset"
      ],
      "cell_type": "markdown"
    },
    {
      "metadata": {},
      "source": [
        "# Create the dataset that the data rows will be uploaded into and report its ID\n",
        "dataset = client.create_dataset(\n",
        "    name=\"Geospatial vessel detection\"\n",
        ")\n",
        "print(f\"Created dataset with ID: {dataset.uid}\")"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "# Import Data Rows with Metadata"
      ],
      "cell_type": "markdown"
    },
    {
      "metadata": {},
      "source": [
        "# Here is an example of adding two metadata fields to your Data Rows: a \"captureDateTime\" field with datetime value, and a \"tag\" field with string value\n",
        "metadata_ontology = client.get_data_row_metadata_ontology()\n",
        "datetime_schema_id = metadata_ontology.reserved_by_name[\"captureDateTime\"].uid\n",
        "tag_schema_id = metadata_ontology.reserved_by_name[\"tag\"].uid\n",
        "tag_items = [\"WorldView-1\", \"WorldView-2\", \"WorldView-3\", \"WorldView-4\"]\n",
        "\n",
        "for datarow in tqdm(data_rows):\n",
        "    # Random timezone-aware UTC timestamp within the next 30 days\n",
        "    # (datetime.utcnow() is deprecated and returns a naive value that carries no timezone)\n",
        "    dt = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=random.random()*30)\n",
        "    tag_item = random.choice(tag_items) # this is a random tag value\n",
        "\n",
        "    # Option 1: Specify metadata with a list of DataRowMetadataField. This is the recommended option since it comes with validation for metadata fields.\n",
        "    metadata_fields = [\n",
        "        DataRowMetadataField(schema_id=datetime_schema_id, value=dt),\n",
        "        DataRowMetadataField(schema_id=tag_schema_id, value=tag_item),\n",
        "    ]\n",
        "\n",
        "    # Option 2: Uncomment to try. Alternatively, you can specify the metadata fields with dictionary format without declaring the DataRowMetadataField objects. It is equivalent to Option 1.\n",
        "    # metadata_fields = [\n",
        "    #     {\"schema_id\": datetime_schema_id, \"value\": dt},\n",
        "    #     {\"schema_id\": tag_schema_id, \"value\": tag_item},\n",
        "    # ]\n",
        "\n",
        "    # Each data row dictionary carries its metadata under the \"metadata_fields\" key\n",
        "    datarow[\"metadata_fields\"] = metadata_fields"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "# Upload the data rows (with their metadata) and wait for the task to finish\n",
        "task = dataset.create_data_rows(data_rows)\n",
        "task.wait_till_done()\n",
        "print(f\"Failed data rows: {task.failed_data_rows}\")\n",
        "print(f\"Errors: {task.errors}\")\n",
        "\n",
        "# If the global key already exists in the workspace the dataset will be created empty, so we can delete it.\n",
        "# The check runs once for the whole error list so the dataset is never queried or deleted again after deletion.\n",
        "has_duplicate_key_error = any(\n",
        "    'Duplicate global key' in error['message'] for error in (task.errors or [])\n",
        ")\n",
        "if has_duplicate_key_error and dataset.row_count == 0:\n",
        "    print(f\"Deleting empty dataset: {dataset}\")\n",
        "    dataset.delete()"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "# Examine a Data Row"
      ],
      "cell_type": "markdown"
    },
    {
      "metadata": {},
      "source": [
        "# Fetch the first data row of the dataset to inspect what was uploaded;\n",
        "# the default of None avoids a StopIteration error when the dataset is empty\n",
        "datarow = next(dataset.data_rows(), None)\n",
        "print(datarow)"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "# Setup a labeling project"
      ],
      "cell_type": "markdown"
    },
    {
      "metadata": {},
      "source": [
        "# Class names come from the 'categories' section of the downloaded annotations\n",
        "category_names = [category['name'] for category in annotations['categories']]\n",
        "\n",
        "# Register one bounding-box tool per class name\n",
        "ontology_builder = OntologyBuilder()\n",
        "for tool_name in category_names:\n",
        "    print(tool_name)\n",
        "    ontology_builder.add_tool(Tool(tool=Tool.Type.BBOX, name=tool_name))\n",
        "\n",
        "# Create the ontology in Labelbox from the builder's dictionary form\n",
        "ontology = client.create_ontology(\n",
        "    \"Vessel Detection Ontology\",\n",
        "    ontology_builder.asdict(),\n",
        "    media_type=lb.MediaType.Image,\n",
        ")\n",
        "print(f\"Created ontology with ID: {ontology.uid}\")\n",
        "\n",
        "# Create the image project and attach the ontology to its editor\n",
        "project = client.create_project(\n",
        "    name=\"Vessel Detection\",\n",
        "    media_type=lb.MediaType.Image,\n",
        ")\n",
        "project.setup_editor(ontology=ontology)\n",
        "print(f\"Created project with ID: {project.uid}\")"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "# Send a batch of data rows to the project"
      ],
      "cell_type": "markdown"
    },
    {
      "metadata": {},
      "source": [
        "client.enable_experimental = True\n",
        "\n",
        "# Minimal export parameters: only the data row details are needed to read the IDs\n",
        "export_params = {\n",
        "    \"data_row_details\": True\n",
        "}\n",
        "\n",
        "# Re-fetch the dataset created above and start a streamable export of its data rows\n",
        "dataset = client.get_dataset(dataset.uid)\n",
        "export_task = dataset.export(params=export_params)\n",
        "export_task.wait_till_done()\n",
        "print(export_task)\n",
        "\n",
        "# The IDs get their own list so the `data_rows` uploaded earlier are not overwritten\n",
        "data_row_ids = []\n",
        "\n",
        "def collect_data_row_id(output: lb.JsonConverterOutput):\n",
        "    \"\"\"Stream callback: append the data row ID found in one exported JSON record.\"\"\"\n",
        "    data = json.loads(output.json_str)\n",
        "\n",
        "    # Records without a data row ID are ignored\n",
        "    if 'data_row' in data and 'id' in data['data_row']:\n",
        "        data_row_ids.append(data['data_row']['id'])\n",
        "\n",
        "# Print export errors, if any, instead of silently ending up with a short ID list\n",
        "if export_task.has_errors():\n",
        "    export_task.get_stream(\n",
        "        converter=lb.JsonConverter(),\n",
        "        stream_type=lb.StreamType.ERRORS\n",
        "    ).start(stream_handler=lambda error: print(error))\n",
        "\n",
        "# Process the stream if there are results\n",
        "if export_task.has_result():\n",
        "    export_task.get_stream(\n",
        "        converter=lb.JsonConverter(),\n",
        "        stream_type=lb.StreamType.RESULT\n",
        "    ).start(stream_handler=collect_data_row_id)\n",
        "\n",
        "# Randomly select 200 Data Rows (or fewer if the dataset has less than 200 data rows)\n",
        "sampled_data_rows = random.sample(data_row_ids, min(len(data_row_ids), 200))\n",
        "\n",
        "# Create a new batch in the project and add the sampled data rows\n",
        "batch = project.create_batch(\n",
        "    \"Initial batch\",  # name of the batch\n",
        "    sampled_data_rows,  # list of Data Row IDs\n",
        "    1  # priority between 1-5\n",
        ")\n",
        "print(f\"Created batch with ID: {batch.uid}\")"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "# Create annotations payload"
      ],
      "cell_type": "markdown"
    },
    {
      "metadata": {},
      "source": [
        "# Set export parameters focused on data row details\n",
        "export_params = {\n",
        "    \"data_row_details\": True,  # Only export data row details\n",
        "    \"batch_ids\": [batch.uid],  # Optional: Include batch ids to filter by specific batches\n",
        "}\n",
        "\n",
        "# Initialize the streamable export task from project\n",
        "export_task = project.export(params=export_params)\n",
        "export_task.wait_till_done()\n",
        "\n",
        "data_rows = []\n",
        "\n",
        "def json_stream_handler(output: lb.JsonConverterOutput):\n",
        "    \"\"\"Stream callback: parse one exported JSON record and collect it in data_rows.\"\"\"\n",
        "    data_rows.append(json.loads(output.json_str))\n",
        "\n",
        "# Print export errors, if any\n",
        "if export_task.has_errors():\n",
        "    export_task.get_stream(\n",
        "        converter=lb.JsonConverter(),\n",
        "        stream_type=lb.StreamType.ERRORS\n",
        "    ).start(stream_handler=lambda error: print(error))\n",
        "\n",
        "if export_task.has_result():\n",
        "    export_task.get_stream(\n",
        "        converter=lb.JsonConverter(),\n",
        "        stream_type=lb.StreamType.RESULT\n",
        "    ).start(stream_handler=json_stream_handler)\n",
        "\n",
        "# Index the annotation file once so each data row is resolved with dictionary lookups\n",
        "# instead of rescanning every image and every annotation.\n",
        "image_ids_by_file_name = {}\n",
        "for image in annotations['images']:\n",
        "    image_ids_by_file_name.setdefault(image['file_name'], []).append(image['id'])\n",
        "\n",
        "annotations_by_image_id = {}\n",
        "for annotation in annotations['annotations']:\n",
        "    annotations_by_image_id.setdefault(annotation['image_id'], []).append(annotation)\n",
        "\n",
        "# Only classes that have a tool in the ontology builder are imported. The names are kept in\n",
        "# their own variable so the `ontology` object created earlier is not overwritten.\n",
        "tool_names = {tool['name'] for tool in ontology_builder.asdict()['tools']}\n",
        "\n",
        "labels = []\n",
        "for datarow in data_rows:\n",
        "    annotations_list = []\n",
        "    # Access the 'data_row' dictionary first\n",
        "    data_row_dict = datarow['data_row']\n",
        "    # The external ID is split on \"/\" into a folder and a file name\n",
        "    external_id_parts = data_row_dict['external_id'].split(\"/\")\n",
        "    folder = external_id_parts[0]\n",
        "    file_name = external_id_parts[1] if len(external_id_parts) > 1 else None\n",
        "\n",
        "    # Only data rows from the positive image set receive annotations; all others get an empty label\n",
        "    if folder == \"positive_image_set\":\n",
        "        for image_id in image_ids_by_file_name.get(file_name, []):\n",
        "            for annotation in annotations_by_image_id.get(image_id, []):\n",
        "                # bbox is read as [x, y, width, height]\n",
        "                bbox = annotation['bbox']\n",
        "                # category_id is used as a 1-based position in annotations['categories']\n",
        "                # NOTE(review): assumes the categories list is ordered by id starting at 1 - confirm\n",
        "                category_id = annotation['category_id'] - 1\n",
        "                class_name = annotations['categories'][category_id]['name']\n",
        "                if class_name and class_name in tool_names:\n",
        "                    annotations_list.append(ObjectAnnotation(\n",
        "                        name=class_name,\n",
        "                        value=Rectangle(\n",
        "                            start=Point(x=bbox[0], y=bbox[1]),\n",
        "                            end=Point(x=bbox[2]+bbox[0], y=bbox[3]+bbox[1])\n",
        "                        )\n",
        "                    ))\n",
        "    image_data = ImageData(uid=data_row_dict['id'])\n",
        "    labels.append(Label(data=image_data, annotations=annotations_list))"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "# Submit all labels to the project as a single import job with a unique name\n",
        "upload_job = lb.LabelImport.create_from_objects(\n",
        "    client=client,\n",
        "    project_id=project.uid,\n",
        "    name=f\"label_import_job_{uuid.uuid4()}\",\n",
        "    labels=labels,\n",
        ")\n",
        "\n",
        "# Block until the import has finished, then report its outcome\n",
        "upload_job.wait_until_done()\n",
        "\n",
        "print(f\"Errors: {upload_job.errors}\")\n",
        "print(f\"Status of uploads: {upload_job.statuses}\")"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    }
  ]
}